{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and load model and artifacts\n",
    "In this notebook I will show the different options to save and load a model, as well as some additional objects produced during training. \n",
    "\n",
    "On a given day, you train a model..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import shutil\n",
    "\n",
    "from pytorch_widedeep.preprocessing import TabPreprocessor\n",
    "from pytorch_widedeep.training import Trainer\n",
    "from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint, LRHistory\n",
    "from pytorch_widedeep.models import TabMlp, WideDeep\n",
    "from pytorch_widedeep.metrics import Accuracy\n",
    "from pytorch_widedeep.datasets import load_adult\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational-num      marital-status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital-gain  capital-loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours-per-week native-country income  \n",
       "0              40  United-States  <=50K  \n",
       "1              50  United-States  <=50K  \n",
       "2              40  United-States   >50K  \n",
       "3              40  United-States   >50K  \n",
       "4              30  United-States  <=50K  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = load_adult(as_frame=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country  target  \n",
       "0              40  United-States       0  \n",
       "1              50  United-States       0  \n",
       "2              40  United-States       1  \n",
       "3              40  United-States       1  \n",
       "4              30  United-States       0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For convenience, we'll replace '-' with '_'\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "# binary target\n",
    "df[\"target\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop(\"income\", axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, valid = train_test_split(df, test_size=0.2, stratify=df.target)\n",
    "# the test data will be used lately as if it was \"fresh\", new data coming after some time...\n",
    "valid, test = train_test_split(valid, test_size=0.5, stratify=valid.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train shape: (39073, 15)\n",
      "valid shape: (4884, 15)\n",
      "test shape: (4885, 15)\n"
     ]
    }
   ],
   "source": [
    "print(f\"train shape: {train.shape}\")\n",
    "print(f\"valid shape: {valid.shape}\")\n",
    "print(f\"test shape: {test.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_embed_cols = [\n",
    "    \"workclass\",\n",
    "    \"education\",\n",
    "    \"marital_status\",\n",
    "    \"occupation\",\n",
    "    \"relationship\",\n",
    "    \"race\",\n",
    "    \"gender\",\n",
    "    \"capital_gain\",\n",
    "    \"capital_loss\",\n",
    "    \"native_country\",\n",
    "]\n",
    "continuous_cols = [\"age\", \"hours_per_week\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised\n",
      "  warnings.warn(\"Continuous columns will not be normalised\")\n"
     ]
    }
   ],
   "source": [
    "tab_preprocessor = TabPreprocessor(\n",
    "    embed_cols=cat_embed_cols,\n",
    "    continuous_cols=continuous_cols,\n",
    ")\n",
    "X_tab_train = tab_preprocessor.fit_transform(train)\n",
    "y_train = train.target.values\n",
    "X_tab_valid = tab_preprocessor.transform(valid)\n",
    "y_valid = valid.target.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    cont_norm_layer=\"layernorm\",\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=8,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "model = WideDeep(deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (deeptabular): Sequential(\n",
       "    (0): TabMlp(\n",
       "      (cat_embed): DiffSizeCatEmbeddings(\n",
       "        (embed_layers): ModuleDict(\n",
       "          (emb_layer_workclass): Embedding(10, 5, padding_idx=0)\n",
       "          (emb_layer_education): Embedding(17, 8, padding_idx=0)\n",
       "          (emb_layer_marital_status): Embedding(8, 5, padding_idx=0)\n",
       "          (emb_layer_occupation): Embedding(16, 7, padding_idx=0)\n",
       "          (emb_layer_relationship): Embedding(7, 4, padding_idx=0)\n",
       "          (emb_layer_race): Embedding(6, 4, padding_idx=0)\n",
       "          (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
       "          (emb_layer_capital_gain): Embedding(122, 23, padding_idx=0)\n",
       "          (emb_layer_capital_loss): Embedding(97, 21, padding_idx=0)\n",
       "          (emb_layer_native_country): Embedding(43, 13, padding_idx=0)\n",
       "        )\n",
       "        (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (cont_norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)\n",
       "      (cont_embed): ContEmbeddings(\n",
       "        INFO: [ContLinear = weight(n_cont_cols, embed_dim) + bias(n_cont_cols, embed_dim)]\n",
       "        (linear): ContLinear(n_cont_cols=2, embed_dim=8, embed_dropout=0.0)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (encoder): MLP(\n",
       "        (mlp): Sequential(\n",
       "          (dense_layer_0): Sequential(\n",
       "            (0): Linear(in_features=108, out_features=64, bias=True)\n",
       "            (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (2): Dropout(p=0.2, inplace=False)\n",
       "          )\n",
       "          (dense_layer_1): Sequential(\n",
       "            (0): Linear(in_features=64, out_features=32, bias=True)\n",
       "            (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "            (2): Dropout(p=0.2, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=32, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 76.25it/s, loss=0.452, metrics={'acc': 0.7867}]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 125.36it/s, loss=0.335, metrics={'acc': 0.8532}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1: val_loss improved from inf to 0.33532 Saving model to tmp_dir/adult_tabmlp_model_1.p\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 2: 100%|██████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 76.98it/s, loss=0.355, metrics={'acc': 0.8401}]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 106.51it/s, loss=0.303, metrics={'acc': 0.8665}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 2: val_loss improved from 0.33532 to 0.30273 Saving model to tmp_dir/adult_tabmlp_model_2.p\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 3: 100%|███████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 82.71it/s, loss=0.332, metrics={'acc': 0.849}]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 107.80it/s, loss=0.288, metrics={'acc': 0.8757}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 3: val_loss improved from 0.30273 to 0.28791 Saving model to tmp_dir/adult_tabmlp_model_3.p\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 4: 100%|███████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 79.02it/s, loss=0.32, metrics={'acc': 0.8541}]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 127.07it/s, loss=0.282, metrics={'acc': 0.8763}]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 4: val_loss improved from 0.28791 to 0.28238 Saving model to tmp_dir/adult_tabmlp_model_4.p\n",
      "Model weights restored to best epoch: 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "early_stopping = EarlyStopping()\n",
    "model_checkpoint = ModelCheckpoint(\n",
    "    filepath=\"tmp_dir/adult_tabmlp_model\",\n",
    "    save_best_only=True,\n",
    "    verbose=1,\n",
    "    max_save=1,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    "    callbacks=[early_stopping, model_checkpoint],\n",
    "    metrics=[Accuracy],\n",
    ")\n",
    "\n",
    "trainer.fit(\n",
    "    X_train={\"X_tab\": X_tab_train, \"target\": y_train},\n",
    "    X_val={\"X_tab\": X_tab_valid, \"target\": y_valid},\n",
    "    n_epochs=4,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save model: option 1\n",
    "\n",
    "save (and load) a model as you woud do with any other torch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model, \"tmp_dir/model_saved_option_1.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), \"tmp_dir/model_state_dict_saved_option_1.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save model: option 2\n",
    "\n",
    "use the `trainer`. The `trainer` will also save the training history and the learning rate history (if learning rate schedulers are used)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save(path=\"tmp_dir/\", model_filename=\"model_saved_option_2.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "or the state dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save(\n",
    "    path=\"tmp_dir/\",\n",
    "    model_filename=\"model_state_dict_saved_option_2.pt\",\n",
    "    save_state_dict=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "adult_tabmlp_model_4.p\n",
      "\u001b[1m\u001b[34mhistory\u001b[m\u001b[m\n",
      "model_saved_option_1.pt\n",
      "model_saved_option_2.pt\n",
      "model_state_dict_saved_option_1.pt\n",
      "model_state_dict_saved_option_2.pt\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "ls tmp_dir/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_eval_history.json\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "ls tmp_dir/history/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that since we have used the `ModelCheckpoint` Callback, `adult_tabmlp_model_2.p` is the model state dict of the model at epoch 2, i.e. same as `model_state_dict_saved_option_1.p` or `model_state_dict_saved_option_2.p`. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save preprocessors and callbacks\n",
    "\n",
    "...just pickle them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"tmp_dir/tab_preproc.pkl\", \"wb\") as dp:\n",
    "    pickle.dump(tab_preprocessor, dp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"tmp_dir/eary_stop.pkl\", \"wb\") as es:\n",
    "    pickle.dump(early_stopping, es)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "adult_tabmlp_model_4.p\n",
      "eary_stop.pkl\n",
      "\u001b[1m\u001b[34mhistory\u001b[m\u001b[m\n",
      "model_saved_option_1.pt\n",
      "model_saved_option_2.pt\n",
      "model_state_dict_saved_option_1.pt\n",
      "model_state_dict_saved_option_2.pt\n",
      "tab_preproc.pkl\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "ls tmp_dir/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And that is pretty much all you need to resume training or directly predict, let's see"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run New experiment: prepare new dataset, load model, and predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10103</th>\n",
       "      <td>43</td>\n",
       "      <td>Private</td>\n",
       "      <td>198282</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Craft-repair</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31799</th>\n",
       "      <td>20</td>\n",
       "      <td>Private</td>\n",
       "      <td>228686</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Other-service</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19971</th>\n",
       "      <td>26</td>\n",
       "      <td>Private</td>\n",
       "      <td>291968</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Transport-moving</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>44</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3039</th>\n",
       "      <td>48</td>\n",
       "      <td>Private</td>\n",
       "      <td>175958</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20725</th>\n",
       "      <td>18</td>\n",
       "      <td>Private</td>\n",
       "      <td>232024</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>55</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       age workclass  fnlwgt  education  educational_num      marital_status  \\\n",
       "10103   43   Private  198282    HS-grad                9  Married-civ-spouse   \n",
       "31799   20   Private  228686       11th                7  Married-civ-spouse   \n",
       "19971   26   Private  291968    HS-grad                9  Married-civ-spouse   \n",
       "3039    48   Private  175958  Bachelors               13            Divorced   \n",
       "20725   18   Private  232024       11th                7       Never-married   \n",
       "\n",
       "              occupation   relationship   race gender  capital_gain  \\\n",
       "10103       Craft-repair        Husband  White   Male             0   \n",
       "31799      Other-service        Husband  White   Male             0   \n",
       "19971   Transport-moving        Husband  White   Male             0   \n",
       "3039      Prof-specialty  Not-in-family  White   Male             0   \n",
       "20725  Machine-op-inspct      Own-child  White   Male             0   \n",
       "\n",
       "       capital_loss  hours_per_week native_country  target  \n",
       "10103             0              40  United-States       1  \n",
       "31799             0              40  United-States       0  \n",
       "19971             0              44  United-States       0  \n",
       "3039              0              30  United-States       0  \n",
       "20725             0              55  United-States       0  "
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"tmp_dir/tab_preproc.pkl\", \"rb\") as tp:\n",
    "    tab_preprocessor_new = pickle.load(tp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test_tab = tab_preprocessor_new.transform(test)\n",
    "y_test = test.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_mlp_new = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    cat_embed_dropout=0.1,\n",
    "    continuous_cols=continuous_cols,\n",
    "    cont_norm_layer=\"layernorm\",\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=8,\n",
    "    mlp_hidden_dims=[64, 32],\n",
    "    mlp_dropout=0.2,\n",
    "    mlp_activation=\"leaky_relu\",\n",
    ")\n",
    "new_model = WideDeep(deeptabular=tab_mlp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model.load_state_dict(torch.load(\"tmp_dir/model_state_dict_saved_option_2.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model,\n",
    "    objective=\"binary\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "predict: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:00<00:00, 309.83it/s]\n"
     ]
    }
   ],
   "source": [
    "preds = trainer.predict(X_tab=X_test_tab, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8595701125895598"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree(\"tmp_dir/\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
